import os
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.nn.functional as F
import torch.utils.data as torch_data_util
from nlu_model.sim.model_pytorch.textcnn import TextCNN
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
import numpy as np
from loguru import logger
from tqdm import trange, tqdm

def parse_net_result(predict_result_1, predict_result_2):
    """Score each pair of sentence encodings and threshold the score into a label.

    Args:
        predict_result_1: Sequence (e.g. a 2-D tensor) of sentence encodings.
        predict_result_2: Encodings paired index-by-index with
            ``predict_result_1``; must be at least as long.

    Returns:
        A ``(labels, scores)`` tuple of lists. ``scores[i]`` is
        ``cosine_similarity`` of the i-th pair. ``labels[i]`` is ``1`` when
        that score is non-negative and ``-1`` otherwise.
    """
    scores = [
        cosine_similarity(left, predict_result_2[position])
        for position, left in enumerate(predict_result_1)
    ]
    labels = [1 if value >= 0 else -1 for value in scores]
    return labels, scores

def cosine_similarity(x, y, norm=False):
    """Compute the cosine similarity of two vectors ``x`` and ``y``.

    Args:
        x: 1-D ``torch.Tensor`` (any device, with or without grad) or any
            array-like accepted by ``np.asarray``.
        y: Vector of the same length as ``x``.
        norm: When True, map the similarity from [-1, 1] onto [0, 1] via
            ``0.5 * cos + 0.5``.

    Returns:
        The similarity as a Python float. If either vector is all zeros the
        cosine is undefined. In that case 1.0 is returned when the two vectors
        are identical and 0.0 otherwise; this fallback is not rescaled by
        ``norm``.

    Raises:
        ValueError: If ``x`` and ``y`` have different lengths.
    """
    if len(x) != len(y):
        raise ValueError("len(x) != len(y)")

    # Detach from autograd and move to host memory so CUDA tensors work too;
    # non-tensor inputs (lists, numpy arrays) are accepted as-is.
    x, y = (
        v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else np.asarray(v, dtype=float)
        for v in (x, y)
    )

    # Only a vector whose components are *all* zero has no direction.
    # (Checking "any component is zero" would misfire on e.g. ReLU outputs.)
    if not np.any(x) or not np.any(y):
        return 1.0 if np.array_equal(x, y) else 0.0

    cos = float(np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)))
    return 0.5 * cos + 0.5 if norm else cos

class TrainModelPipeline:
    """Train and evaluate a sentence-pair similarity model.

    A single network (built by ``choose_model``) encodes both sentences of a
    pair. Training uses ``nn.CosineEmbeddingLoss`` on the two encodings, and
    evaluation scores pairs by cosine similarity.
    """

    def __init__(self, config):
        """
        Args:
            config: Mapping providing at least ``"MODEL_CONF"`` (passed to
                ``TextCNN``), ``"batch_size"`` and ``"epoch"``.
        """
        self.config = config
        self.net = self.choose_model()
        print(self.net)

        # NOTE(review): the device is recorded, but neither the model nor the
        # batches are moved to it, so training currently runs on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def choose_model(self):
        """Build the network to train from ``config["MODEL_CONF"]``."""
        return TextCNN(self.config["MODEL_CONF"])

    def _dump_parameters(self):
        """Print every parameter tensor of the network (debug output)."""
        paras = list(self.net.parameters())
        logger.info("model has {} parameter tensors", len(paras))
        for num, para in enumerate(paras):
            print('number:', num)
            print(para)
            print('_____________________________')

    def call_train(self, x_train_1, x_train_2, y_train):
        """Train the network on sentence pairs with a cosine-embedding loss.

        Args:
            x_train_1: Token-id sequences for the first sentence of each pair.
            x_train_2: Token-id sequences for the second sentence, paired by
                index with ``x_train_1``.
            y_train: Per-pair targets; ``nn.CosineEmbeddingLoss`` expects 1
                (similar) or -1 (dissimilar).
        """
        # Batch the data.
        x_train_1 = [np.array(sent) for sent in x_train_1]
        x_train_2 = [np.array(sent) for sent in x_train_2]
        torch_dataset = sim_data(x_train_1, x_train_2, y_train)
        dataLoader = torch_data_util.DataLoader(
            dataset=torch_dataset,
            batch_size=self.config["batch_size"],  # samples per batch
            shuffle=True,                          # reshuffle each epoch
            num_workers=0,                         # load in the main process
        )

        # Training setup.
        self._dump_parameters()
        optimizer = torch.optim.Adam(self.net.parameters(), lr=0.001)
        # reduction="mean" is the non-deprecated equivalent of
        # reduce=True, size_average=True.
        criterion = nn.CosineEmbeddingLoss(margin=0.1, reduction="mean")

        # call_evaluate switches the net to eval mode; switch back for training.
        self.net.train()
        for epoch in range(self.config["epoch"]):
            # len(DataLoader) is the batch count; no need to materialize an epoch.
            with trange(len(dataLoader)) as t:
                t.set_description("EPOCH %s" % (epoch + 1))
                for i, (sentences_1, sentences_2, labels) in enumerate(dataLoader):
                    optimizer.zero_grad()
                    out_1 = self.net(sentences_1.long())
                    out_2 = self.net(sentences_2.long())
                    loss = criterion(out_1, out_2, labels.long())
                    # The loss is a scalar, so call backward() with no argument.
                    # Passing the loss as the `gradient` argument would scale
                    # every gradient by the current loss value.
                    loss.backward()
                    optimizer.step()

                    t.set_postfix(loss=loss.item(), batch_num=i + 1)
                    t.update(1)
        self._dump_parameters()

    def call_evaluate(self, x_test_1, x_test_2, y_test):
        """Score held-out sentence pairs and print AUC and classification metrics.

        Args:
            x_test_1: Token-id sequences for the first sentences; ``np.array``
                must turn them into a rectangular array.
            x_test_2: Token-id sequences for the second sentences, paired by
                index with ``x_test_1``.
            y_test: True labels for the pairs, in the same label space as the
                predictions (1 / -1).
        """
        # eval() makes training-only layers (dropout, batch norm) behave
        # deterministically. no_grad() skips building an autograd graph.
        self.net.eval()
        with torch.no_grad():
            # Cast to int64 like the training batches, whatever dtype np.array picked.
            sentences_1 = torch.from_numpy(np.array(x_test_1)).long()
            sentences_2 = torch.from_numpy(np.array(x_test_2)).long()
            predict_1 = self.net(sentences_1)
            predict_2 = self.net(sentences_2)
        label_pres, scores = parse_net_result(predict_1, predict_2)

        print("auc")
        print(roc_auc_score(y_test, scores))
        print(y_test[:10])
        print(scores[:10])
        print("confusion_matrix")
        print(confusion_matrix(y_test, label_pres))
        print("classification_report")
        print(classification_report(y_test, label_pres))


class sim_data(torch_data_util.Dataset):
    """Map-style dataset yielding ``(sentence_1, sentence_2, label)`` triples.

    The three sequences are indexed in lockstep. The dataset length is taken
    from ``x_train_1``.
    """

    def __init__(self, x_train_1, x_train_2, y_train):
        # Keep the parallel sequences as given; items are looked up lazily.
        self.x_train_1, self.x_train_2, self.y_train = x_train_1, x_train_2, y_train

    def __len__(self):
        size = len(self.x_train_1)
        return size

    def __getitem__(self, idx):
        columns = (self.x_train_1, self.x_train_2, self.y_train)
        return tuple(column[idx] for column in columns)